import torch
import os
import cv2
import matplotlib.pyplot as plt
from torchvision.transforms import functional as F
import torchvision
from PIL import Image
from xml.dom.minidom import parse
import utils
import transforms as T
from engine import train_one_epoch, evaluate
import xml.etree.cElementTree as ET
import collections
import pandas as pd
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
import sys
import torch
from torchvision import transforms
from PIL import Image
import numpy as np
import cv2

# Put the dataset you want to visualize in this folder.
root = r'dataset/train'
img_path = os.path.join(root, r'JPEGImages')
xml_path = os.path.join(root, r'Annotations')
# Sort both listings so image/annotation files pair up by index:
# os.listdir() order is OS-dependent and not guaranteed to match.
img_list = sorted(os.listdir(img_path))
xml_list = sorted(os.listdir(xml_path))

print(img_list)
print(xml_list)

# label_list.txt lives two directories above the train folder.
# BUGFIX: the original used `root + "../.."`, which produced the broken
# path 'dataset/train../..'; join the path components instead.
with open(os.path.join(root, "..", "..", "label_list.txt"), 'r') as file:
    label_list = file.readlines()
label_list = [label.rstrip() for label in label_list]  # strip trailing newlines/whitespace
label_list = [label for label in label_list if label != '']  # drop blank lines

# To keep parameters consistent with training, build the full (img, target)
# pair exactly the way the training dataset does, then apply the transform
# under test and display the result with its boxes drawn.
for i in range(len(img_list)):
    img_file = os.path.join(root, 'JPEGImages', img_list[i])
    xml_file = os.path.join(root, 'Annotations', xml_list[i])
    print(img_file)
    print(xml_file)
    # cv2.imread returns BGR; convert to RGB so the tensor matches the
    # PIL-based training pipeline, and so the RGB2BGR conversion before
    # imshow below restores the correct channel order for display.
    img = cv2.cvtColor(cv2.imread(img_file), cv2.COLOR_BGR2RGB)

    # Parse the VOC-style annotation XML for this image.
    dom = parse(xml_file)
    data = dom.documentElement
    objects = data.getElementsByTagName('object')
    boxes = []
    labels = []
    for object_ in objects:
        # <name> holds the label string; store its index into label_list.
        name = object_.getElementsByTagName('name')[0].childNodes[0].nodeValue
        labels.append(label_list.index(name))

        # getElementsByTagName returns a list; each <object> carries exactly
        # one <bndbox>, so take element 0.
        bndbox = object_.getElementsByTagName('bndbox')[0]
        # BUGFIX: np.float was removed in NumPy >= 1.24; the builtin float
        # is the documented equivalent.
        xmin = float(bndbox.getElementsByTagName('xmin')[0].childNodes[0].nodeValue)
        ymin = float(bndbox.getElementsByTagName('ymin')[0].childNodes[0].nodeValue)
        xmax = float(bndbox.getElementsByTagName('xmax')[0].childNodes[0].nodeValue)
        ymax = float(bndbox.getElementsByTagName('ymax')[0].childNodes[0].nodeValue)
        boxes.append([xmin, ymin, xmax, ymax])

    # The reference training code requires everything in target as tensors.
    boxes = torch.as_tensor(boxes, dtype=torch.float32)   # boxes must be float
    labels = torch.as_tensor(labels, dtype=torch.int64)   # labels must be int64

    # target is a dict with at least the 'boxes' and 'labels' keys.
    target = {"boxes": boxes,
              "labels": labels}
    # Mirror training: the augmentation transforms take a torch.Tensor image.
    img = F.to_tensor(img)
    # Instantiate the transform to preview; swap this line to visualize a
    # different augmentation (e.g. T.Cutout(n_holes=1, length_w=100, length_h=10)).
    expand = T.RandomCrop(prob=1.0, threshold=1.4)
    img_transform, target_transform = expand(img, target)

    # tensor (C, H, W) -> numpy (H, W, C) so OpenCV can display it.
    img_transform = np.transpose(img_transform.numpy(), (1, 2, 0))
    img_transform = np.ascontiguousarray(img_transform)  # cv2.rectangle needs contiguous memory
    img_transform = cv2.cvtColor(img_transform, cv2.COLOR_RGB2BGR)
    # Likewise, tensors -> numpy for drawing/indexing.
    target_transform["boxes"] = target_transform["boxes"].numpy()
    target_transform["labels"] = target_transform["labels"].numpy()

    # Draw each transformed box with its label name.
    # BUGFIX: iterate the TRANSFORMED target; the original iterated
    # len(target["labels"]) while indexing target_transform, which raises
    # IndexError if the transform drops boxes (e.g. after a crop).
    for j in range(len(target_transform["labels"])):
        name = label_list[target_transform["labels"][j]]  # stored value is an index into label_list
        x1, y1, x2, y2 = (int(v) for v in target_transform["boxes"][j])
        cv2.rectangle(img_transform, (x1, y1), (x2, y2), (255, 0, 0), thickness=2)
        cv2.putText(img_transform, name, (x1, y1),
                    fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1,
                    color=(0, 255, 0))
    cv2.imshow("img_transform", img_transform)  # imshow needs a numpy (Mat-like) image
    key = cv2.waitKey()
    # Press 'q' or Esc to quit; any other key advances to the next image.
    if key == ord('q') or key == 27:
        break

cv2.destroyAllWindows()
